{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Exploration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of images: 335\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "path2test=\"./data/test_set/\"\n",
    "\n",
    "# Keep every directory entry whose name does not contain \"Annotation\"\n",
    "# (presumably the mask files -- confirm against the dataset layout).\n",
    "# os.listdir returns entries in arbitrary order, so sort them: otherwise the\n",
    "# seeded random sampling done later can pick different files on different machines.\n",
    "imgsList=sorted(pp for pp in os.listdir(path2test) if \"Annotation\" not in pp)\n",
    "print(\"number of images:\", len(imgsList))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['043_HC.png', '011_HC.png', '303_HC.png', '091_HC.png'],\n",
       "      dtype='<U10')"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Fix the legacy global NumPy seed so the draw below is repeatable for a given imgsList.\n",
    "np.random.seed(2019)\n",
    "\n",
    "# Draw four file names (np.random.choice samples with replacement by default).\n",
    "rndImgs=np.random.choice(imgsList, size=4)\n",
    "rndImgs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class SegNet(nn.Module):\n",
    "    def __init__(self, params):\n",
    "        super(SegNet, self).__init__()\n",
    "        \n",
    "        C_in, H_in, W_in=params[\"input_shape\"]\n",
    "        init_f=params[\"initial_filters\"] \n",
    "        num_outputs=params[\"num_outputs\"] \n",
    "\n",
    "        self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3,stride=1,padding=1)\n",
    "        self.conv2 = nn.Conv2d(init_f, 2*init_f, kernel_size=3,stride=1,padding=1)\n",
    "        self.conv3 = nn.Conv2d(2*init_f, 4*init_f, kernel_size=3,padding=1)\n",
    "        self.conv4 = nn.Conv2d(4*init_f, 8*init_f, kernel_size=3,padding=1)\n",
    "        self.conv5 = nn.Conv2d(8*init_f, 16*init_f, kernel_size=3,padding=1)\n",
    "\n",
    "        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n",
    "\n",
    "        self.conv_up1 = nn.Conv2d(16*init_f, 8*init_f, kernel_size=3,padding=1)\n",
    "        self.conv_up2 = nn.Conv2d(8*init_f, 4*init_f, kernel_size=3,padding=1)\n",
    "        self.conv_up3 = nn.Conv2d(4*init_f, 2*init_f, kernel_size=3,padding=1)\n",
    "        self.conv_up4 = nn.Conv2d(2*init_f, init_f, kernel_size=3,padding=1)\n",
    "\n",
    "        self.conv_out = nn.Conv2d(init_f, num_outputs , kernel_size=3,padding=1)    \n",
    "    \n",
    "    def forward(self, x):\n",
    "        \n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "\n",
    "        x = F.relu(self.conv3(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "\n",
    "        x = F.relu(self.conv4(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "\n",
    "        x = F.relu(self.conv5(x))\n",
    "\n",
    "        x=self.upsample(x)\n",
    "        x = F.relu(self.conv_up1(x))\n",
    "\n",
    "        x=self.upsample(x)\n",
    "        x = F.relu(self.conv_up2(x))\n",
    "        \n",
    "        x=self.upsample(x)\n",
    "        x = F.relu(self.conv_up3(x))\n",
    "        \n",
    "        x=self.upsample(x)\n",
    "        x = F.relu(self.conv_up4(x))\n",
    "\n",
    "        x = self.conv_out(x)\n",
    "        \n",
    "        return x "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Height and width of the network input; images are resized to this before inference.\n",
    "h = 128\n",
    "w = 192\n",
    "\n",
    "# Settings consumed by SegNet.__init__: one input channel, 16 filters in the\n",
    "# first convolution, and a single output channel.\n",
    "params_model = {\n",
    "    \"input_shape\": (1, h, w),\n",
    "    \"initial_filters\": 16,\n",
    "    \"num_outputs\": 1,\n",
    "}\n",
    "\n",
    "model = SegNet(params_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Use the second GPU when there is one; on a single-GPU machine 'cuda:1' is an\n",
    "# invalid device ordinal, so fall back to the first GPU, and to the CPU when CUDA\n",
    "# is not available at all.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "model=model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pylab as plt\n",
    "from PIL import Image\n",
    "from scipy import ndimage as ndi\n",
    "from skimage.segmentation import mark_boundaries\n",
    "\n",
    "def show_img_mask(img, mask):\n",
    "    \"\"\"Show `img` with the boundaries of `mask` drawn in green on the current axes.\n",
    "\n",
    "    Both arguments are converted with np.array() and handed to skimage's\n",
    "    mark_boundaries; the resulting overlay is displayed with plt.imshow.\n",
    "    \"\"\"\n",
    "    green = (0, 1, 0)\n",
    "    image_arr = np.array(img)\n",
    "    label_arr = np.array(mask)\n",
    "    overlay = mark_boundaries(image_arr, label_arr, color=green, outline_color=green)\n",
    "    plt.imshow(overlay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SegNet(\n",
       "  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (upsample): Upsample(scale_factor=2.0, mode=bilinear)\n",
       "  (conv_up1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv_up2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv_up3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv_up4): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv_out): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path2weights=\"./models/weights.pt\"\n",
    "\n",
    "# map_location remaps the stored tensors onto `device`, so a checkpoint saved on a\n",
    "# GPU still loads on a CPU-only machine or one with a different GPU layout.\n",
    "# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.\n",
    "model.load_state_dict(torch.load(path2weights, map_location=device))\n",
    "\n",
    "# Switch to evaluation mode; as the cell's last expression this also displays the model.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'path2train' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m-----------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0mTraceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-d49490f6b8dd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrndImgs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m     \u001b[0mpath2img\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath2train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m     \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath2img\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0mimg\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'path2train' is not defined"
     ]
    }
   ],
   "source": [
    "from torchvision.transforms.functional import to_tensor, to_pil_image\n",
    "\n",
    "for fn in rndImgs:\n",
    "    # rndImgs was sampled from the listing of path2test, so read the files from there.\n",
    "    path2img = os.path.join(path2test, fn)\n",
    "    img = Image.open(path2img)\n",
    "    img=img.resize((w,h))  # PIL takes (width, height)\n",
    "    img_t=to_tensor(img).unsqueeze(0).to(device)  # add a leading batch dimension\n",
    "\n",
    "    # Inference only: no need to record operations for autograd.\n",
    "    with torch.no_grad():\n",
    "        pred=torch.sigmoid(model(img_t))[0]\n",
    "\n",
    "    # Threshold the first output channel at 0.5, then copy the mask to a CPU NumPy\n",
    "    # array: matplotlib and np.array() cannot convert a tensor that lives on a GPU.\n",
    "    mask_pred=(pred[0]>=0.5).cpu().numpy()\n",
    "\n",
    "    # One figure per image: input, predicted mask, mask boundary overlaid on input.\n",
    "    plt.figure()\n",
    "    plt.subplot(1, 3, 1) \n",
    "    plt.imshow(img, cmap=\"gray\")\n",
    "\n",
    "    plt.subplot(1, 3, 2) \n",
    "    plt.imshow(mask_pred, cmap=\"gray\")\n",
    "    \n",
    "    plt.subplot(1, 3, 3) \n",
    "    show_img_mask(img, mask_pred)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
